/*
 *
 *  * Copyright 2020 New Relic Corporation. All rights reserved.
 *  * SPDX-License-Identifier: Apache-2.0
 *
 */

package com.newrelic.agent.profile;

import com.newrelic.agent.Agent;
import com.newrelic.agent.ThreadService;
import com.newrelic.agent.instrumentation.InstrumentedClass;
import com.newrelic.agent.instrumentation.InstrumentedMethod;
import com.newrelic.agent.service.ServiceFactory;
import com.newrelic.agent.transport.DataSenderWriter;
import com.newrelic.agent.util.StackTraces;
import org.json.simple.JSONArray;

import java.io.IOException;
import java.io.Writer;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.zip.Deflater;

/**
 * Execution profile over a time period.
 *
 * <p>
 * A profile keeps one {@link ProfileTree} per {@link ThreadType}, records sampled stack traces into those trees,
 * attributes per-thread CPU time to them when the profile ends, and serializes itself as a JSON array. When the
 * profile holds more than {@link #MAX_STACK_SIZE} stack elements, or cannot be serialized within
 * {@link #MAX_ENCODED_DATA_BYTES}, the least significant segments are trimmed.
 */
public class Profile implements IProfile {

    public static final int MAX_STACK_DEPTH = 300;
    /*
     * Since the Collector has a 1000000 byte limit on the Content-Length in the HttpServletRequest, restrict the size
     * of the profile to 60000 stack elements (~750000 bytes).
     */
    public static final int MAX_STACK_SIZE = 60000;

    /**
     * Collector has a 1000000 byte limit on the Content-Length in the HttpServletRequest
     */
    public static final int MAX_ENCODED_BYTES = 1000000;

    /**
     * The maximum size of the JSON payload itself (excluding data) is 114 bytes (4 Longs + 3 Integers). This means
     * that we need to take the collector max (1MB) and subtract our payload size to figure out how much data we
     * can include before potentially going over the limit.
     */
    public static final int MAX_ENCODED_DATA_BYTES = MAX_ENCODED_BYTES - 114;

    /**
     * The number of stack elements by which the allowed stack size is lowered on each trimming pass, for as long as
     * the encoded data cannot be produced within {@link #MAX_ENCODED_DATA_BYTES}.
     */
    public static final int STACK_TRIM = 10000;

    private long startTimeMillis = 0;
    private long endTimeMillis = 0;
    private int sampleCount = 0;
    private int totalThreadCount = 0;
    private int runnableThreadCount = 0;
    private Map<Long, Long> startThreadCpuTimes;
    private final ProfilerParameters profilerParameters;
    private final Map<ThreadType, ProfileTree> profileTrees = new HashMap<>();

    /**
     * Creates a profile for the given profiler parameters.
     *
     * @param parameters the parameters of the profiling session; also the source of the profile id
     */
    public Profile(ProfilerParameters parameters) {
        this.profilerParameters = parameters;
    }

    /**
     * Snapshot the CPU time (in nanoseconds) of every live thread, keyed by thread id.
     *
     * @return the snapshot, or an empty map when thread CPU time measurement is unsupported or disabled
     */
    private Map<Long, Long> getThreadCpuTimes() {
        ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean();
        if (!(threadMXBean.isThreadCpuTimeSupported() && threadMXBean.isThreadCpuTimeEnabled())) {
            return Collections.emptyMap();
        }

        Map<Long, Long> cpuTimes = new HashMap<>();
        for (long id : threadMXBean.getAllThreadIds()) {
            long cpuTime = threadMXBean.getThreadCpuTime(id);
            // getThreadCpuTime reports -1 for a thread that died after getAllThreadIds() returned; recording it
            // would later produce a negative CPU time delta.
            if (cpuTime >= 0) {
                cpuTimes.put(id, cpuTime);
            }
        }
        return cpuTimes;
    }

    /**
     * Returns the profile tree for the given thread type, creating and registering it on first use.
     */
    @Override
    public ProfileTree getProfileTree(ThreadType threadType) {
        ProfileTree profileTree = profileTrees.get(threadType);
        if (profileTree == null) {
            profileTree = new ProfileTree();
            profileTrees.put(threadType, profileTree);
        }
        return profileTree;
    }

    /**
     * Records the start time and the starting per-thread CPU times, and logs why CPU time cannot be recorded when
     * the JVM does not support it or has it disabled.
     *
     * Subclasses may override.
     */
    @Override
    public void start() {
        startTimeMillis = System.currentTimeMillis();
        startThreadCpuTimes = getThreadCpuTimes();

        ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean();
        if (!threadMXBean.isThreadCpuTimeSupported()) {
            Agent.LOG.info("Profile unable to record CPU time: Thread CPU time measurement is not supported");
        } else if (!threadMXBean.isThreadCpuTimeEnabled()) {
            Agent.LOG.info("Profile unable to record CPU time: Thread CPU time measurement is not enabled");
        }
    }

    /**
     * Records the end time, adds the CPU time each thread consumed since {@link #start()} to the AGENT or OTHER
     * profile tree, and trims the profile down to {@link #MAX_STACK_SIZE} stack elements if it is larger.
     *
     * Subclasses may override.
     */
    @Override
    public void end() {
        endTimeMillis = System.currentTimeMillis();

        Map<Long, Long> endThreadCpuTimes = getThreadCpuTimes();

        // start() may not have run (or an overriding start() may not have taken a snapshot); treat that the same
        // as "no thread was known at start" instead of failing with a NullPointerException.
        Map<Long, Long> startTimes = startThreadCpuTimes;
        if (startTimes == null) {
            startTimes = Collections.emptyMap();
        }

        ThreadService threadService = ServiceFactory.getThreadService();
        final Set<Long> agentThreadIds = threadService.getAgentThreadIds();

        for (Entry<Long, Long> entry : endThreadCpuTimes.entrySet()) {
            // a thread absent from the start snapshot is charged its full CPU time
            Long startTime = startTimes.get(entry.getKey());
            long startNanos = startTime == null ? 0L : startTime;
            long cpuTime = TimeUnit.MILLISECONDS.convert(entry.getValue() - startNanos, TimeUnit.NANOSECONDS);

            ProfileTree tree;
            if (agentThreadIds.contains(entry.getKey())) {
                tree = getProfileTree(ThreadType.BasicThreadType.AGENT);
            } else {
                tree = getProfileTree(ThreadType.BasicThreadType.OTHER);
            }
            tree.incrementCpuTime(cpuTime);
        }

        int stackCount = getCallSiteCount();
        String msg = MessageFormat.format("Profile size is {0} stack elements", stackCount);
        Agent.LOG.info(msg);
        if (stackCount > MAX_STACK_SIZE) {
            Agent.LOG.info(MessageFormat.format("Trimmed profile size by {0} stack elements", trim(stackCount
                    - MAX_STACK_SIZE, stackCount)));
        }
    }

    /**
     * Use the loaded classes to mark all of the {@link ProfiledMethod}s which are instrumented using our method
     * annotations. Any failure is logged and otherwise ignored.
     *
     * @see InstrumentedClass
     * @see InstrumentedMethod
     */
    @Override
    public void markInstrumentedMethods() {
        try {
            doMarkInstrumentedMethods();
        } catch (Throwable ex) {
            String msg = MessageFormat.format("Error marking instrumented methods {0}", ex);
            // only attach the stack trace at the most verbose log level
            if (Agent.LOG.isLoggable(Level.FINEST)) {
                Agent.LOG.log(Level.FINEST, msg, ex);
            } else {
                Agent.LOG.finer(msg);
            }
        }
    }

    /**
     * Builds a class-name to class map from all currently loaded classes and hands it to every profile tree.
     */
    private void doMarkInstrumentedMethods() {
        Class<?>[] allLoadedClasses = ServiceFactory.getCoreService().getInstrumentation().getAllLoadedClasses();
        Map<String, Class<?>> classMap = new HashMap<>();
        for (Class<?> clazz : allLoadedClasses) {
            classMap.put(clazz.getName(), clazz);
        }
        for (ProfileTree tree : profileTrees.values()) {
            tree.setMethodDetails(classMap);
        }
    }

    /**
     * For testing.
     */
    @Override
    public int trimBy(int limit) {
        return trim(limit, getCallSiteCount());
    }

    /**
     * Reduce the size of the profile by removing segments with the lowest call count (first priority) and highest depth
     * in the tree (second priority).
     *
     * <p>
     * The returned value is the number of segments visited for removal. Root segments have no parent to be detached
     * from, so visiting one counts towards the result without changing the tree.
     *
     * @param limit the maximum number of segments to remove
     * @param stackCount the current number of segments in the profile, used to size the working collection
     * @return the number of segments processed, never more than {@code limit}
     */
    private int trim(int limit, int stackCount) {
        if (limit <= 0) {
            // nothing to remove; skip collecting and sorting every segment
            return 0;
        }
        ProfileSegmentSort[] segments = getSortedSegments(stackCount);
        int count = 0;
        for (ProfileSegmentSort segment : segments) {
            if (count >= limit) {
                break;
            }
            segment.remove();
            count++;
        }
        return count;
    }

    /**
     * Get a sorted array of all segments in the profile.
     *
     * @param stackCount the expected number of segments; only a sizing hint, the result always holds exactly the
     *        segments found while walking the trees
     */
    private ProfileSegmentSort[] getSortedSegments(int stackCount) {
        List<ProfileSegmentSort> segments = new ArrayList<>(Math.max(stackCount, 0));
        for (ProfileTree profileTree : profileTrees.values()) {
            for (ProfileSegment rootSegment : profileTree.getRootSegments()) {
                addSegment(rootSegment, null, 1, segments);
            }
        }
        Collections.sort(segments);
        return segments.toArray(new ProfileSegmentSort[0]);
    }

    /**
     * Adds the given segment and, recursively, all of its descendants to {@code segments}.
     *
     * @param segment the segment to add
     * @param parent the parent of {@code segment}, or null for a root segment
     * @param depth the depth of {@code segment} in its tree; root segments are at depth 1
     * @param segments the collection receiving the sortable wrappers
     */
    private void addSegment(ProfileSegment segment, ProfileSegment parent, int depth,
            List<ProfileSegmentSort> segments) {
        segments.add(new ProfileSegmentSort(segment, parent, depth));
        // every child sits exactly one level below this segment, regardless of its position among its siblings
        int childDepth = depth + 1;
        for (ProfileSegment child : segment.getChildren()) {
            addSegment(child, segment, childDepth, segments);
        }
    }

    /**
     * Get the number of distinct method invocation nodes in the profile.
     */
    private int getCallSiteCount() {
        int count = 0;
        for (ProfileTree profileTree : profileTrees.values()) {
            count += profileTree.getCallSiteCount();
        }
        return count;
    }

    @Override
    public Long getProfileId() {
        return profilerParameters.getProfileId();
    }

    @Override
    public ProfilerParameters getProfilerParameters() {
        return profilerParameters;
    }

    /**
     * Counts one more sample.
     */
    @Override
    public void beforeSampling() {
        sampleCount++;
    }

    @Override
    public int getSampleCount() {
        return sampleCount;
    }

    @Override
    public final long getStartTimeMillis() {
        return startTimeMillis;
    }

    @Override
    public final long getEndTimeMillis() {
        return endTimeMillis;
    }

    /**
     * Writes the profile as a JSON array of: profile id, start time, end time, sample count, encoded profile data,
     * total thread count and runnable thread count.
     */
    @Override
    public void writeJSONString(Writer out) throws IOException {
        JSONArray.writeJSONString(Arrays.asList(profilerParameters.getProfileId(), startTimeMillis, endTimeMillis,
                sampleCount, getData(out), totalThreadCount, runnableThreadCount), out);
    }

    /**
     * Encodes the profile trees, trimming the profile in steps of {@link #STACK_TRIM} stack elements for as long as
     * the encoder returns no result (presumably because the data exceeds {@link #MAX_ENCODED_DATA_BYTES}).
     *
     * @return the encoded data, or null if no result could be produced even after trimming
     */
    private Object getData(Writer out) {
        Object result = DataSenderWriter.getJsonifiedOptionallyCompressedEncodedString(profileTrees, out,
                Deflater.BEST_SPEED, MAX_ENCODED_DATA_BYTES);

        // trim if necessary until encoded/compressed size is under the collector's threshold
        int maxStack = MAX_STACK_SIZE;
        while (result == null && maxStack > 0) {
            maxStack -= STACK_TRIM;
            int stackCount = getCallSiteCount();
            trim(stackCount - maxStack, stackCount);
            result = DataSenderWriter.getJsonifiedOptionallyCompressedEncodedString(profileTrees, out,
                    Deflater.BEST_SPEED, MAX_ENCODED_DATA_BYTES);
        }

        if (result != null && DataSenderWriter.isCompressingWriter(out)) {
            String msg = MessageFormat.format("Profile serialized size = {0} bytes", result.toString().length());
            Agent.LOG.info(msg);
        }
        return result;
    }

    /**
     * Counts one more sampled thread, and one more runnable thread if {@code runnable} is set.
     */
    private void incrementThreadCounts(boolean runnable) {
        totalThreadCount++;
        if (runnable) {
            runnableThreadCount++;
        }
    }

    /**
     * Stacks are scrubbed unless they belong to an agent thread or agent threads are being profiled.
     */
    private boolean shouldScrubStack(ThreadType type) {
        if (ThreadType.BasicThreadType.AGENT.equals(type)) {
            return false;
        }
        if (profilerParameters.isProfileAgentThreads()) {
            return false;
        }
        return true;
    }

    /**
     * Adds one sampled stack trace to the profile tree of the given thread type. Stack traces with fewer than two
     * elements are ignored and not counted.
     *
     * Subclasses may override.
     *
     * @param threadId the id of the sampled thread
     * @param runnable whether the thread was runnable when sampled
     * @param type the type of the sampled thread; selects the profile tree
     * @param stackTrace the sampled stack, leaf first
     */
    @Override
    public void addStackTrace(long threadId, boolean runnable, ThreadType type, StackTraceElement... stackTrace) {
        if (stackTrace.length < 2) {
            return;
        }

        incrementThreadCounts(runnable);

        List<StackTraceElement> stackTraceList;
        if (shouldScrubStack(type)) {
            stackTraceList = StackTraces.scrubAndTruncate(Arrays.asList(stackTrace), 0);
        } else {
            stackTraceList = Arrays.asList(stackTrace);
        }
        // copy before reversing so the caller's array (or the scrubbed list) is left untouched
        List<StackTraceElement> result = new ArrayList<>(stackTraceList);

        // the stack traces we get start with the leaves, not the roots. flip them
        Collections.reverse(result);

        getProfileTree(type).addStackTrace(result, runnable);
    }

    /**
     * A class to sort profile segments in order of lowest runnable call count (first) and highest depth in the stack
     * (second).
     *
     * Note: this class has a natural ordering that is inconsistent with equals.
     */
    private static class ProfileSegmentSort implements Comparable<ProfileSegmentSort> {

        private final ProfileSegment segment;
        private final ProfileSegment parent;
        private final int depth;

        private ProfileSegmentSort(ProfileSegment segment, ProfileSegment parent, int depth) {
            this.segment = segment;
            this.parent = parent;
            this.depth = depth;
        }

        /**
         * Detaches the segment from its parent. Root segments (no parent) are left in place.
         */
        void remove() {
            if (parent != null) {
                parent.removeChild(segment.getMethod());
            }
        }

        @Override
        public String toString() {
            return segment.toString();
        }

        @Override
        public int compareTo(ProfileSegmentSort other) {
            int byCallCount = Integer.compare(segment.getRunnableCallCount(),
                    other.segment.getRunnableCallCount());
            if (byCallCount != 0) {
                return byCallCount;
            }
            // equal call counts: the deeper segment sorts first
            return Integer.compare(other.depth, depth);
        }
    }

}
